{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4d62708d-606c-40b3-821e-c173a6c8825e",
   "metadata": {},
   "source": [
    "# Object detection batch inference on test dataset and metrics calculation \n",
    "\n",
    "The previous notebook fine-tuned a custom Faster R-CNN model for mask detection. \n",
    "\n",
    "This notebook continues with evaluations that use a test dataset and metrics calculation to assess the model quality. Evaluations are critical for verifying that your object detection model accurately identifies objects and meets performance benchmarks, such as mean Average Precision and Intersection over Union. \n",
    "\n",
    "By running these evaluations, you can pinpoint strengths and weaknesses, ensuring the model generalizes well to new data. **Ray Data on Anyscale accelerates this process by enabling parallel batch inference across multiple GPU nodes, significantly reducing evaluation time**. This streamlined workflow allows for faster iterations and timely insights into model performance, ultimately leading to more reliable deployments.\n",
    "\n",
    "This tutorial demonstrates how to:\n",
    "\n",
    "1. **Load the fine-tuned model** from the saved weights from AWS S3 to cluster storage on Anyscale.\n",
    "2. **Process test images and annotations** using a custom VOC-format datasource.\n",
    "4. **Run batch inference** using Ray Data leveraging GPU acceleration.\n",
    "5. **Evaluate model performance** using object detection metrics (calculating mAP and mAR with TorchMetrics).\n",
    "\n",
    "Here is the overview of the pipeline:\n",
    "\n",
    "\n",
    "<img\n",
    "  src=\"https://face-masks-data.s3.us-east-2.amazonaws.com/tutorial-diagrams/batch_inference_metrics_calculation.png\"\n",
    "  alt=\"Object Detection Batch Inferece Pipeline - Metrics Calculation\"\n",
    "  style=\"width:75%;\"\n",
    "/>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2db2443e-fab7-433d-a0ff-2aa6ab5128c7",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-block alert-warning\">\n",
    "  <b>Anyscale-specific configuration</b>\n",
    "  \n",
    "  <p>Note: This tutorial is optimized for the Anyscale platform. When running on open source Ray, additional configuration is required. For example, you'll need to manually:</p>\n",
    "  \n",
    "  <ul>\n",
    "    <li>\n",
    "      <b>Configure your Ray Cluster:</b> Set up your multi-node environment, including head and worker nodes, and manage resource allocation, like autoscaling and GPU/CPU assignments, without the Anyscale automation. See <a href=\"https://docs.ray.io/en/latest/cluster/getting-started.html\">Ray Clusters</a> for details.\n",
    "    </li>\n",
    "    <li>\n",
    "      <b>Manage dependencies:</b> Install and manage dependencies on each node since you won’t have Anyscale’s Docker-based dependency management. See <a href=\"https://docs.ray.io/en/latest/ray-core/handling-dependencies.html\">Environment Dependencies</a> for instructions on installing and updating Ray in your environment.\n",
    "    </li>\n",
    "    <li>\n",
    "      <b>Set up storage:</b> Configure your own distributed or shared storage system (instead of relying on Anyscale’s integrated cluster storage). See <a href=\"https://docs.ray.io/en/latest/train/user-guides/persistent-storage.html\">Configuring Persistent Storage</a> for suggestions on setting up shared storage solutions.\n",
    "    </li>\n",
    "  </ul>\n",
    "\n",
    "</div>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76c6e64a-4f8a-4a10-80eb-32f69bf804a0",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Imports, class mappings, and visualization colors \n",
    "\n",
    "Start by importing all necessary libraries for data handling, model loading, image processing, and metrics calculation.\n",
    "\n",
    "Also define the class-to-label mapping, and its reverse, along with colors for visualizing detection results.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e73743c0-c211-4305-b9b0-690459b6e3d7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# %% \n",
    "import os\n",
    "import io\n",
    "import requests\n",
    "import numpy as np\n",
    "import torch\n",
    "from PIL import Image, ImageDraw, ImageFont\n",
    "import xmltodict\n",
    "\n",
    "import ray\n",
    "import pyarrow as pa\n",
    "from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder\n",
    "from ray.data.block import Block\n",
    "from ray.data.datasource import FileBasedDatasource\n",
    "\n",
    "from torchvision import models, transforms\n",
    "from torchvision.utils import draw_bounding_boxes\n",
    "from torchvision.transforms.functional import to_pil_image, convert_image_dtype, to_tensor\n",
    "from torchmetrics.detection.mean_ap import MeanAveragePrecision\n",
    "from functools import partial\n",
    "from ray.data import DataContext\n",
    "DataContext.get_current().enable_fallback_to_arrow_object_ext_type = True\n",
    "\n",
    "\n",
    "# Define the mapping for classes.\n",
    "CLASS_TO_LABEL = {\n",
    "    \"background\": 0,\n",
    "    \"with_mask\": 1,\n",
    "    \"without_mask\": 2,\n",
    "    \"mask_weared_incorrect\": 3\n",
    "}\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e5a47aa-452b-4fb0-9d7a-6caddcdeb2db",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Load the fine‑tuned object detection model from S3 to Anyscale cluster storage\n",
    "Load the fine‑tuned Faster R-CNN model from the previous training notebook, from  AWS S3 to Anyscale cluster storage.  \n",
    "\n",
    "### Why use cluster storage\n",
    "\n",
    "* Avoid redundant S3 reads: Multiple workers reading from S3 simultaneously can cause throttling, latency, and increased costs.\n",
    "\n",
    "* Faster model loading: Cluster storage, like shared filesystem or object store, for model weight loading is typically faster than remote S3.\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "610c4493-6652-423a-b3fa-54fcb85c2f94",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "\n",
    "from smart_open import open as smart_open\n",
    "import os\n",
    "import torch\n",
    "from torchvision import models\n",
    "import boto3\n",
    "from botocore import UNSIGNED\n",
    "from botocore.config import Config\n",
    "\n",
    "# Paths\n",
    "remote_model_path = \"s3://face-masks-data/finetuned-models/fasterrcnn_model_mask_detection.pth\"\n",
    "cluster_model_path = \"/mnt/cluster_storage/fasterrcnn_model_mask_detection.pth\"  \n",
    "\n",
    "# Download model only once.\n",
    "if not os.path.exists(cluster_model_path):\n",
    "    # Create S3 client, falling back to unsigned for public buckets\n",
    "    session = boto3.Session()\n",
    "    # session.get_credentials() will return None if no credentials can be found.\n",
    "    if session.get_credentials():\n",
    "        # If credentials are found, use a standard signed client.\n",
    "        s3_client = session.client(\"s3\")\n",
    "    else:\n",
    "        # No credentials found, fall back to an unsigned client for public buckets.\n",
    "        s3_client = boto3.client(\"s3\", config=Config(signature_version=UNSIGNED))\n",
    "    \n",
    "    transport_params = {\"client\": s3_client}\n",
    "    \n",
    "    with smart_open(remote_model_path, \"rb\", transport_params=transport_params) as s3f, open(cluster_model_path, \"wb\") as out:\n",
    "        out.write(s3f.read())\n",
    "\n",
    "# Load the model (driver verifies it works).\n",
    "loaded_model = models.detection.fasterrcnn_resnet50_fpn(num_classes=len(CLASS_TO_LABEL))\n",
    "loaded_model.load_state_dict(torch.load(cluster_model_path, map_location=\"cpu\"))\n",
    "loaded_model.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b536149c-c86b-4246-a482-44667a3343f0",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Create the test dataset using Ray Data \n",
    "Similarly to creating the training dataset in the first notebook, create your test dataset by reading the annotation files from S3 using a custom datasource and then joining the annotations with the images.\n",
    "\n",
    "In this case, because the dataset is relatively small, the S3 directory may not contain enough distinct data chunks or files to automatically create separate blocks. To improve parallelism, you can explicitly use `override_num_blocks=2`. This matches the later configuration of using 2 GPUs to process the data. \n",
    "\n",
    "\n",
    "For more details, see: \n",
    "https://docs.ray.io/en/latest/data/api/doc/ray.data.read_binary_files.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5a7b34a-51d0-4226-884b-7b0aaa42b38d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from typing import Dict\n",
    "\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "from functools import partial\n",
    "\n",
    "import os\n",
    "import ray\n",
    "\n",
    "def parse_voc_annotation(record) -> dict:\n",
    "    xml_str = record[\"bytes\"].decode(\"utf-8\")\n",
    "    if not xml_str.strip():\n",
    "        raise ValueError(\"Empty XML string\")\n",
    "        \n",
    "    annotation = xmltodict.parse(xml_str)[\"annotation\"]\n",
    "\n",
    "    # Normalize the object field to a list.\n",
    "    objects = annotation[\"object\"]\n",
    "    if isinstance(objects, dict):\n",
    "        objects = [objects]\n",
    "\n",
    "    boxes: List[Tuple] = []\n",
    "    for obj in objects:\n",
    "        x1 = float(obj[\"bndbox\"][\"xmin\"])\n",
    "        y1 = float(obj[\"bndbox\"][\"ymin\"])\n",
    "        x2 = float(obj[\"bndbox\"][\"xmax\"])\n",
    "        y2 = float(obj[\"bndbox\"][\"ymax\"])\n",
    "        boxes.append((x1, y1, x2, y2))\n",
    "\n",
    "    labels: List[int] = [CLASS_TO_LABEL[obj[\"name\"]] for obj in objects]\n",
    "    filename = annotation[\"filename\"]\n",
    "\n",
    "    return {\n",
    "        \"boxes\": np.array(boxes),\n",
    "        \"labels\": np.array(labels),\n",
    "        \"filename\": filename\n",
    "    }\n",
    "\n",
    "\n",
    "\n",
    "def read_images(images_s3_url:str, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:\n",
    "    images: List[np.ndarray] = []\n",
    "    \n",
    "    for filename in batch[\"filename\"]:\n",
    "        \n",
    "        if not filename.lower().endswith((\".png\", \".jpg\", \".jpeg\", \".bmp\", \".gif\")):\n",
    "            continue\n",
    "            \n",
    "        url = os.path.join(images_s3_url, filename)\n",
    "        response = requests.get(url)\n",
    "        image = Image.open(io.BytesIO(response.content)).convert(\"RGB\")  # Ensure image is in RGB.\n",
    "\n",
    "        images.append(np.array(image))\n",
    "    batch[\"image\"] = np.array(images, dtype=object)\n",
    "    return batch\n",
    "\n",
    "\n",
    "\n",
    "test_annotation_s3_uri = \"s3://face-masks-data/test/annotations/\"\n",
    "ds = ray.data.read_binary_files(test_annotation_s3_uri, override_num_blocks=2)\n",
    "annotations = ds.map(parse_voc_annotation)\n",
    "\n",
    "test_images_s3_url = \"https://face-masks-data.s3.us-east-2.amazonaws.com/test/images/\"\n",
    "test_read_images = partial(read_images, test_images_s3_url)\n",
    "test_dataset = annotations.map_batches(test_read_images)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ec631fd8-ab3a-4521-a355-df7281fea614",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Define the batch object detection model for inference\n",
    "\n",
    "Define the `BatchObjectDetectionModel` class to encapsulate the detection logic, which you can later use with the `map_batches` function in Ray Data.\n",
    "\n",
    "Ray Data allows for two approaches when applying transformations like `map` or `map_batches`:\n",
    "\n",
    "* **Functions**: These use stateless Ray tasks, which are ideal for simple operations that don’t require loading heavyweight models.\n",
    "* **Classes**: These use stateful Ray actors, making them well-suited for more complex tasks involving heavyweight models—**exactly what you need in this case**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "270c70ff-0c11-411d-8458-3d03e4a663c9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "class BatchObjectDetectionModel:\n",
    "    def __init__(self):\n",
    "        self.model = loaded_model\n",
    "        if torch.cuda.is_available():\n",
    "            self.model = self.model.cuda()\n",
    "\n",
    "    def __call__(self, batch: dict) -> dict:\n",
    "        predictions = []\n",
    "        for image_np in batch[\"image\"]:\n",
    "            image_tensor = torch.from_numpy(image_np).permute(2, 0, 1).float() / 255.0\n",
    "            if torch.cuda.is_available():\n",
    "                image_tensor = image_tensor.cuda()\n",
    "            with torch.no_grad():\n",
    "                pred = self.model([image_tensor])[0]\n",
    "            predictions.append({\n",
    "                \"boxes\": pred[\"boxes\"].detach().cpu().numpy(),\n",
    "                \"labels\": pred[\"labels\"].detach().cpu().numpy(),\n",
    "                \"scores\": pred[\"scores\"].detach().cpu().numpy()\n",
    "            })\n",
    "        batch[\"predictions\"] = predictions\n",
    "        return batch\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08ab25be-b2ff-423f-bbdd-0f34332bc76a",
   "metadata": {},
   "source": [
    "## Run batch inference on the Dataset\n",
    "Using Ray Data’s `map_batches`, perform batch inference with your model. \n",
    "\n",
    "Configure the process to run with a batch size of 4, concurrency of 2, and if available, 1 GPU per worker. \n",
    "\n",
    "Note that this configuration is intended solely for demonstration purposes. In real-world scenarios, you can adjust the concurrency level, GPU allocation (based on available GPUs and desired inference speed), and batch size (based on GPU memory constraints) to optimize performance.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef5b36a5-b7f0-40b6-9e09-00275f84ac64",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Use 2 concurrent actors with batch_size 4 and request 1 GPU per worker.\n",
    "# In total you are using 2 GPU nodes.\n",
    "inference_dataset = test_dataset.map_batches(\n",
    "    BatchObjectDetectionModel,\n",
    "    batch_size=4,\n",
    "    compute=ray.data.ActorPoolStrategy(size=2),\n",
    "    num_gpus=1\n",
    ")\n",
    "results = inference_dataset.take_all()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b42fbe8-ddce-4d02-9199-4ac29a44f28d",
   "metadata": {},
   "source": [
    "## Process predictions and compute evaluation metrics\n",
    "Next, convert the predictions and ground truth annotations into a format compatible with TorchMetrics. Then update the metric with these values.\n",
    "\n",
    "**Note**: You can further improve efficiency by combining the batch prediction step with the metric calculation step using a Ray Data pipeline. However, for clarity, this straightforward code illustrates the intuitive approach.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9796b7bc-6479-484f-b353-b71668fa76ef",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Prepare lists for predictions and targets.\n",
    "preds_list = []\n",
    "targets_list = []\n",
    "\n",
    "for record in results:\n",
    "    # Each record corresponds to a single image.\n",
    "    pred_dict = record[\"predictions\"]\n",
    "    # Convert predictions to tensors.\n",
    "    pred = {\n",
    "        \"boxes\": torch.as_tensor(pred_dict[\"boxes\"]),\n",
    "        \"scores\": torch.as_tensor(pred_dict[\"scores\"]),\n",
    "        \"labels\": torch.as_tensor(pred_dict[\"labels\"])\n",
    "    }\n",
    "    preds_list.append(pred)\n",
    "    \n",
    "    # Ground truth data for the image.\n",
    "    gt_boxes = record[\"boxes\"]\n",
    "    gt_labels = record[\"labels\"]\n",
    "    target = {\n",
    "        \"boxes\": torch.as_tensor(gt_boxes),\n",
    "        \"labels\": torch.as_tensor(gt_labels)\n",
    "    }\n",
    "    targets_list.append(target)\n",
    "\n",
    "# Initialize the metric.\n",
    "metric = MeanAveragePrecision()\n",
    "\n",
    "print(\"preds_list[1]:\", preds_list[1])\n",
    "print(\"targets_list[1]:\", targets_list[1])\n",
    "# Update metric with the predictions and targets.\n",
    "metric.update(preds_list, targets_list)\n",
    "\n",
    "# Compute the results.\n",
    "map_results = metric.compute()\n",
    "print(\"Mean Average Precision (mAP) results:\")\n",
    "print(map_results)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "906a52ad-26f2-43fe-837d-1a0c5ee6f2f1",
   "metadata": {},
   "source": [
    "## Evaluation metrics\n",
    "Finally, define helper functions to format and print the evaluation metrics in a clear, human-readable format.\n",
    "\n",
    "### Intersection over Union\n",
    "Intersection over Union (IoU) is a fundamental metric used in object detection to evaluate the accuracy of a predicted bounding box compared to the ground-truth bounding box. It measures the overlap between the two bounding boxes by calculating the ratio of the area of their intersection to the area of their union.\n",
    "\n",
    "### Overall Mean Average Precision (mAP)\n",
    "Mean Average Precision is the primary metric used for evaluating object detection models. It measures the average precision (AP) across different classes and IoU different IoU thresholds, for example, from 0.5 to 0.95.\n",
    "\n",
    "### Precision at specific IoU thresholds\n",
    "IoU measures the overlap between predicted and ground-truth bounding boxes.\n",
    "\n",
    "* map_50: AP when IoU = 0.50 (PASCAL VOC standard).\n",
    "* map_75: AP when IoU = 0.75 (more strict matching criteria).\n",
    "\n",
    "These values help assess how well the model performs at different levels of bounding box overlap.\n",
    "\n",
    "### Mean Average Precision (mAP) by object size\n",
    "Object detection models often perform differently based on object sizes. This section evaluates performance based on object size categories:\n",
    "\n",
    "* `map_small`: mAP for small objects. For example, tiny objects like a face in a crowd.\n",
    "* `map_medium`: mAP for medium-sized objects.\n",
    "* `map_large`: mAP for large objects.\n",
    "\n",
    "This metric helps you understand whether the model struggles with small or large objects.\n",
    "\n",
    "### Mean Average Recall (mAR) at various detection counts\n",
    "Recall measures how well the model finds all relevant objects.\n",
    "\n",
    "* `mar_1`: mAR when considering only the top 1 prediction per object.\n",
    "* `mar_10`: mAR when considering the top 10 predictions.\n",
    "* `mar_100`: mAR when considering the top 100 predictions.\n",
    "\n",
    "This metric is useful for analyzing the model’s ability to detect multiple instances of objects.\n",
    "\n",
    "### Mean Average Recall (mAR) by object size\n",
    "Similar to mAP, but focused on recall:\n",
    "\n",
    "* `mar_small`: mAR for small objects.\n",
    "* `mar_medium`: mAR for medium-sized objects.\n",
    "* `mar_large`: mAR for large objects.\n",
    "\n",
    "This metric helps you diagnose whether the model is missing detections in certain object size ranges.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad79946c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def format_tensor_value(value):\n",
    "    \"\"\"Convert a torch.Tensor to a scalar or list if necessary.\"\"\"\n",
    "    if isinstance(value, torch.Tensor):\n",
    "        # If the tensor is a scalar, extract its Python number.\n",
    "        if value.ndim == 0:\n",
    "            return value.item()\n",
    "        else:\n",
    "            # Convert non-scalar tensors to list.\n",
    "            return value.tolist()\n",
    "    return value\n",
    "\n",
    "def print_evaluation_metrics(results):\n",
    "    print(\"Evaluation Metrics Overview\")\n",
    "    print(\"=\" * 40)\n",
    "    \n",
    "    # Overall mAP\n",
    "    print(\"Overall Mean Average Precision (mAP):\")\n",
    "    print(f\"  mAP: {format_tensor_value(results['map'])}\\n\")\n",
    "    \n",
    "    # Precision at Specific IoU thresholds.\n",
    "    print(\"Precision at Specific IoU Thresholds:\")\n",
    "    print(f\"  mAP@0.50: {format_tensor_value(results['map_50'])}\")\n",
    "    print(f\"  mAP@0.75: {format_tensor_value(results['map_75'])}\\n\")\n",
    "    \n",
    "    # mAP by Object Size.\n",
    "    print(\"Mean Average Precision by Object Size:\")\n",
    "    print(f\"  Small Objects (mAP_small): {format_tensor_value(results['map_small'])}\")\n",
    "    print(f\"  Medium Objects (mAP_medium): {format_tensor_value(results['map_medium'])}\")\n",
    "    print(f\"  Large Objects (mAP_large): {format_tensor_value(results['map_large'])}\\n\")\n",
    "    \n",
    "    # MAR at Various Detection Counts.\n",
    "    print(\"Mean Average Recall (MAR) at Various Detection Counts:\")\n",
    "    print(f\"  MAR@1: {format_tensor_value(results['mar_1'])}\")\n",
    "    print(f\"  MAR@10: {format_tensor_value(results['mar_10'])}\")\n",
    "    print(f\"  MAR@100: {format_tensor_value(results['mar_100'])}\\n\")\n",
    "    \n",
    "    # MAR by Object Size.\n",
    "    print(\"Mean Average Recall by Object Size:\")\n",
    "    print(f\"  Small Objects (MAR_small): {format_tensor_value(results['mar_small'])}\")\n",
    "    print(f\"  Medium Objects (MAR_medium): {format_tensor_value(results['mar_medium'])}\")\n",
    "    print(f\"  Large Objects (MAR_large): {format_tensor_value(results['mar_large'])}\\n\")\n",
    "    \n",
    "\n",
    "\n",
    "print_evaluation_metrics(map_results)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
